#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar  7 13:32:19 2022

Simulation of sliced 2-Wasserstein distance
"""

import ot
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def generate_uniform_sphere(d, n, R):
    """Draw ``n`` points uniformly distributed on the sphere of radius ``R`` in R^d.

    Each point is a standard normal vector rescaled to have norm ``R``; the
    isotropic Gaussian is rotation invariant, so the direction is uniform.

    Parameters
    ----------
    d : int
        Ambient dimension.
    n : int
        Number of points to draw.
    R : float
        Radius of the sphere.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)`` whose ``j``-th row is the ``j``-th point.
    """
    # One (n, d) draw consumes the global NumPy RNG in the same order as n
    # separate (1, d) draws, so seeded runs reproduce the old per-row loop.
    temp = np.random.normal(0, 1, size=(n, d))
    norms = np.linalg.norm(temp, axis=1, keepdims=True)
    return R * temp / norms

def cov_simulate(d, m, angles=None):
    """Monte Carlo estimate of the variance of the limiting Gaussian distribution.

    With directions ``u_1, ..., u_m`` (the rows ``U`` of ``angles``) and
    ``s_i = sum(u_i)`` (the sum of the coordinates of ``u_i``), this returns

        (8 / 3) * (1 / m**2) * sum_{i, j} s_i * s_j * <u_i, u_j>.

    The double sum equals ``||U^T s||^2``, so it is evaluated in O(m * d)
    instead of with an O(m**2) pairwise loop. The constant 8/3 is taken as
    given; its derivation is not part of this file.

    Parameters
    ----------
    d : int
        Ambient dimension.
    m : int
        Number of Monte Carlo directions.
    angles : array_like, optional
        Pre-drawn directions of shape ``(m, d)``. When ``None`` (default),
        ``m`` directions are drawn uniformly from the unit sphere in R^d.

    Returns
    -------
    float
        The variance estimate; ``0.0`` when ``m == 0``.

    Raises
    ------
    ValueError
        If ``angles`` is given and its shape is not ``(m, d)``.
    """
    if angles is None:
        angles = generate_uniform_sphere(d, m, 1)
    else:
        angles = np.asarray(angles, dtype=float)
        if angles.shape != (m, d):
            raise ValueError(
                f"angles must have shape ({m}, {d}), got {angles.shape}"
            )
    if m == 0:
        # Empty double sum; avoid the 0/0 division below.
        return 0.0
    s = angles.sum(axis=1)   # s_i = sum of coordinates of u_i
    proj = angles.T @ s      # sum_i s_i * u_i
    return float(8.0 * np.dot(proj, proj) / 3.0 / m**2)
    

# --- Simulation parameters ---------------------------------------------------
R = 1                 # radius of both spheres
d = 3                 # ambient dimension
# Population SW_2^2 between the sphere and its shift by (1, ..., 1): along a
# direction theta the projections differ only by the translation <theta, 1>,
# and E[<theta, 1>^2] = ||1||^2 / d = 1 for theta uniform on the sphere.
rswd = 1
vaS = cov_simulate(d,1000)   # variance of the limiting Gaussian

sample_sizes = [50,100,500]
n_reps = 500          # Monte Carlo replications per sample size
n_projections = 1000  # random projections per sliced-Wasserstein estimate
n_seed = 10           # projection seeds averaged within each replication

m = 1                 # matplotlib figure number, one figure per sample size
xs = np.linspace(-5,5,500)   # plotting grid
# Density of the limiting N(0, vaS) law on the plotting grid.
limSdens = np.exp(-xs**2/(2*vaS))/np.sqrt(2*vaS*np.pi)

for n in sample_sizes:
    # Uniform weights on both empirical samples.
    a, b = np.ones((n,)) / n, np.ones((n,)) / n
    swd = np.empty((n_reps,))
    for i in range(n_reps):
        datap = generate_uniform_sphere(d,n,R)+1
        dataq = generate_uniform_sphere(d,n,R)
        smp = np.empty((n_seed,))
        # Seeds 0..n_seed-1 are reused in every replication, so the same
        # projection sets are used throughout; their estimates are averaged.
        for seed in range(n_seed):
            smp[seed] = ot.sliced_wasserstein_distance(
                datap, dataq, a, b, n_projections, seed=seed)**2
        swd[i] = np.mean(smp)
    # Summary statistics (not used below).
    swd_mean = np.mean(swd)
    # Centre at the population value and scale by sqrt(n).
    swd = np.sqrt(n)*(swd - rswd)
    swd_std = np.std(swd)
    density = gaussian_kde(swd,'silverman')
    kde_vals = density(xs)   # evaluate the KDE once, reuse for both plot calls
    plt.figure(m)
    plt.plot(xs,kde_vals,color='cadetblue')
    plt.fill_between(xs,kde_vals,color='paleturquoise',alpha=0.5)
    plt.plot(xs,limSdens,color='palevioletred')
    plt.fill_between(xs,limSdens,color='pink',alpha=0.5)
    plt.xlabel("x")
    plt.ylabel("Density")
    plt.title('sample size n = '+str(n))
    m += 1

# Without this the figures are never displayed when run as a plain script.
plt.show()
